In [1]:
import jax.numpy as jnp
import plotly.express as px
from plotly.subplots import make_subplots
import jax
import numpy as np
from datasets import mnist
import plotly.graph_objects as go
In [2]:
train_images, train_labels, test_images, test_labels = mnist()
train_images = train_images.astype(jnp.float32)
test_images = test_images.astype(jnp.float32)
# the labels arrive one-hot encoded over the 10 digit classes
train_labels = jnp.asarray(train_labels, dtype=jnp.int32)
test_labels = jnp.asarray(test_labels, dtype=jnp.int32)
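Before building the network, it's worth confirming the layout the rest of the notebook assumes: flattened 784-pixel images and one-hot labels over the 10 digit classes (the local `datasets.mnist` helper appears to follow the JAX examples convention; the shapes in the comments are assumptions based on that):

    # a quick sanity check of the shapes the later cells rely on
    print(train_images.shape, train_labels.shape)  # expect (60000, 784) (60000, 10)
    print(test_images.shape, test_labels.shape)    # expect (10000, 784) (10000, 10)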
In [34]:
def visualize_images(images_tensor, w=28, h=28, col_wrap=5):
    img = images_tensor.reshape(-1, w, h)
    fig = px.imshow(img, binary_string=False, facet_col=0, facet_col_wrap=col_wrap)
    # blank out the default "facet_col=i" annotation above each image
    item_map = {str(i): "" for i in range(img.shape[0])}
    fig.for_each_annotation(lambda a: a.update(text=item_map[a.text.split("=")[1]]))
    fig.show()
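For example, to view the first ten training digits in two rows of five:

    visualize_images(train_images[:10])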
In [4]:
net_parameters = {
    'w0': np.random.randn(256, 784) * 0.1,
    'w1': np.random.randn(256, 256) * 0.1,
    'w2': np.random.randn(256, 256) * 0.1,
    'w3': np.random.randn(10, 256) * 0.1,
}
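As an aside, the same initialization can be made reproducible with JAX's explicit PRNG instead of NumPy's global one; a minimal sketch (the key value 0 is arbitrary):

    key = jax.random.PRNGKey(0)
    shapes = [(256, 784), (256, 256), (256, 256), (10, 256)]
    net_parameters = {
        f'w{i}': jax.random.normal(k, shape) * 0.1
        for i, (k, shape) in enumerate(zip(jax.random.split(key, 4), shapes))
    }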
In [6]:
def ReLU(x):
    return jnp.maximum(0, x)

def forward(parameters, x):
    # x: (batch, 784) -> logits: (batch, 10)
    x = x.T
    x = parameters['w0'] @ x
    x = ReLU(x)
    x = parameters['w1'] @ x
    x = ReLU(x)
    x = parameters['w2'] @ x
    x = ReLU(x)
    x = parameters['w3'] @ x
    return x.T
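The transposes let each `(out_dim, in_dim)` weight matrix left-multiply column vectors, so `forward` maps a `(batch, 784)` batch to `(batch, 10)` logits:

    print(forward(net_parameters, test_images[:5]).shape)  # (5, 10)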
In [7]:
def loss(parameters, x, y):
    # mean cross-entropy against the one-hot labels y
    out = forward(parameters, x)
    out = jax.nn.softmax(out)
    _loss = -(y * jnp.log(out)).sum(axis=-1).mean()
    return _loss

loss(net_parameters, test_images, test_labels)
Out[7]:
Array(2.8375754, dtype=float32)
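One caveat: `jnp.log(softmax(x))` can underflow to `-inf` when a probability rounds to zero. A numerically safer sketch of the same loss uses `jax.nn.log_softmax`, which fuses the two steps:

    def stable_loss(parameters, x, y):
        # identical in exact arithmetic, but log_softmax never materializes
        # probabilities that can round to 0
        log_probs = jax.nn.log_softmax(forward(parameters, x))
        return -(y * log_probs).sum(axis=-1).mean()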
In [8]:
# accuracy of the untrained network: ~10%, i.e. chance level on 10 classes
(forward(net_parameters, train_images).argmax(axis=-1) == train_labels.argmax(axis=-1)).mean()
Out[8]:
Array(0.10371667, dtype=float32)
In [9]:
grad_loss = jax.grad(loss)
lr = 0.1
# keep track of all the previous gradients
grad_history = []
for epoch in range(100):
    # full-batch gradient descent: one gradient step per epoch
    p_grad = grad_loss(net_parameters, train_images, train_labels)
    grad_history.append(p_grad)
    net_parameters['w0'] -= lr * p_grad['w0']
    net_parameters['w1'] -= lr * p_grad['w1']
    net_parameters['w2'] -= lr * p_grad['w2']
    net_parameters['w3'] -= lr * p_grad['w3']
    print(f"epoch {epoch}")
    print(f"validation loss: {loss(net_parameters, test_images, test_labels)}")
    print(f"train loss: {loss(net_parameters, train_images, train_labels)}")
    acc = (forward(net_parameters, train_images).argmax(axis=-1) == train_labels.argmax(axis=-1)).mean()
    print(f"accuracy: {acc}")
    print("\n")
epoch 0
validation loss: 2.9189021587371826
train loss: 2.8750860691070557
accuracy: 0.164000004529953

epoch 1
validation loss: 2.4470043182373047
train loss: 2.4396541118621826
accuracy: 0.2874833345413208

epoch 2
validation loss: 1.8561428785324097
train loss: 1.8763872385025024
accuracy: 0.3936833441257477

[... epochs 3-96 elided: the losses fall steadily, apart from a transient spike around epochs 16-19 where the train loss jumps back above 1.0 before recovering ...]

epoch 97
validation loss: 0.28640565276145935
train loss: 0.294091135263443
accuracy: 0.9133833646774292

epoch 98
validation loss: 0.2854039669036865
train loss: 0.292802095413208
accuracy: 0.9140333533287048

epoch 99
validation loss: 0.28412362933158875
train loss: 0.2915572226047516
accuracy: 0.9142833352088928
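The four hand-written update lines generalize to any parameter dictionary; a sketch of the same full-batch step using `jax.tree_util.tree_map`, with the gradient function jitted for speed:

    grad_loss_jit = jax.jit(jax.grad(loss))
    for epoch in range(100):
        p_grad = grad_loss_jit(net_parameters, train_images, train_labels)
        # apply p <- p - lr * g to every weight matrix in one call
        net_parameters = jax.tree_util.tree_map(
            lambda p, g: p - lr * g, net_parameters, p_grad
        )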
In [10]:
im = 0
visualize_images(test_images[im])
forward(net_parameters, test_images[im])
Out[10]:
Array([-0.6755133 , -4.4337187 , -0.31666017,  1.5255424 , -1.6595467 ,
       -0.63005555, -2.952984  ,  8.591555  , -0.7318695 ,  1.9986991 ],      dtype=float32)
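The largest logit is at index 7, so the network predicts this test image is a 7; pushing the logits through a softmax makes the confidence explicit:

    logits = forward(net_parameters, test_images[im])
    print(int(logits.argmax()))                 # 7
    print(float(jax.nn.softmax(logits).max()))  # probability mass on the top class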
In [11]:
# the magnitude of the gradient at each training step
grad_norms = {
    'w0': [],
    'w1': [],
    'w2': [],
    'w3': [],
}
for grad_vector in grad_history:
    grad_norms['w0'].append(np.linalg.norm(grad_vector['w0'].flatten()))
    grad_norms['w1'].append(np.linalg.norm(grad_vector['w1'].flatten()))
    grad_norms['w2'].append(np.linalg.norm(grad_vector['w2'].flatten()))
    grad_norms['w3'].append(np.linalg.norm(grad_vector['w3'].flatten()))
fig = px.line(grad_norms)
fig.show()
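Equivalently, since `np.linalg.norm` on a matrix already returns the Frobenius norm, the whole block collapses to one comprehension:

    grad_norms = {
        key: [float(np.linalg.norm(g[key])) for g in grad_history]
        for key in net_parameters
    }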
In [12]:
# for each training step, calculate the angle between the current and previous gradient
grad_cosines = {
    'w0': [],
    'w1': [],
    'w2': [],
    'w3': [],
}
grad_angles = {
    'w0': [],
    'w1': [],
    'w2': [],
    'w3': [],
}
for i in range(1, len(grad_history)):
    for key in ['w0', 'w1', 'w2', 'w3']:
        g_i = grad_history[i][key].flatten()
        g_i_norm = g_i / np.linalg.norm(g_i)
        g_im1 = grad_history[i - 1][key].flatten()
        g_im1_norm = g_im1 / np.linalg.norm(g_im1)
        cos = g_i_norm @ g_im1_norm
        grad_cosines[key].append(cos)
        # clip to [-1, 1] so float error can't push arccos out of its domain
        angle = np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))
        grad_angles[key].append(angle)

fig_0 = px.line(grad_cosines, title="Cosine Between Consecutive Gradients")
fig_0.show()
fig_1 = px.line(grad_angles, title="Angle Between Consecutive Gradients")
fig_1.show()
In [45]:
# Find the similarity between each gradient and every other gradient (per weight matrix)
for key in ['w0', 'w1', 'w2', 'w3']:
    # the history of this parameter's gradient over the whole training run
    history = [gradient_dict[key] for gradient_dict in grad_history]
    # convert from a list to a NumPy array
    history = np.array(history)
    # shape: (training_epochs, output_dim, input_dim)
    training_epochs, output_dim, input_dim = history.shape
    history = history.reshape(training_epochs, output_dim * input_dim)
    # normalize the gradient vector for each time step
    magnitudes = np.linalg.norm(history, axis=-1)
    history = (history.T / magnitudes).T
    # the cosine of the angle between the gradients of every pair of steps;
    # already (training_epochs, training_epochs), so no reshape is needed
    similarity_matrix = history @ history.T
    fig = px.imshow(similarity_matrix, title=f"Similarity Matrix for {key} Gradients")
    fig.update_layout(
        autosize=False,
        width=800,
        height=800,
        margin=dict(l=50, r=50, b=100, t=100, pad=4),
    )
    fig.show()
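A couple of quick properties to sanity-check the last matrix computed above (the `w3` gradients): the diagonal holds self-cosines of 1, the matrix is symmetric, and the first off-diagonal should reproduce the consecutive-step cosines from In [12], up to float32 tolerance:

    assert np.allclose(np.diag(similarity_matrix), 1.0, atol=1e-5)
    assert np.allclose(similarity_matrix, similarity_matrix.T)
    # should print True: entry (i, i+1) is cos(g_i, g_{i+1}), the same quantity
    # computed pairwise in In [12]
    print(np.allclose(np.diag(similarity_matrix, k=1), grad_cosines['w3'], atol=1e-4))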